{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# PPO for Portfolio Management\n",
    "This tutorial is to demonstrate an example of using PPO to do portfolio management"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step1: Import Packages"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import sys\n",
    "from pathlib import Path\n",
    "import os\n",
    "import torch\n",
    "\n",
    "ROOT = str(Path(__file__).resolve().parents[2])\n",
    "sys.path.append(ROOT)\n",
    "\n",
    "import argparse\n",
    "import os.path as osp\n",
    "from mmcv import Config\n",
    "from trademaster.utils import replace_cfg_vals\n",
    "from trademaster.nets.builder import build_net\n",
    "from trademaster.environments.builder import build_environment\n",
    "from trademaster.datasets.builder import build_dataset\n",
    "from trademaster.agents.builder import build_agent\n",
    "from trademaster.optimizers.builder import build_optimizer\n",
    "from trademaster.losses.builder import build_loss\n",
    "from trademaster.trainers.builder import build_trainer"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step2: Import Configs"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def parse_args():\n",
    "    parser = argparse.ArgumentParser(description='Download Alpaca Datasets')\n",
    "    parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"portfolio_management\", \"portfolio_management_exchange_ppo_ppo_adam_mse.py\"),\n",
    "                        help=\"download datasets config file path\")\n",
    "    parser.add_argument(\"--task_name\", type=str, default=\"train\")\n",
    "    parser.add_argument(\"--test_style\", type=str, default=\"-1\")\n",
    "    args = parser.parse_args()\n",
    "    return args\n",
    "\n",
    "args = parse_args()\n",
    "cfg = Config.fromfile(args.config)\n",
    "task_name = args.task_name\n",
    "cfg = replace_cfg_vals(cfg)\n",
    "print(cfg)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step3: Build Dataset"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "dataset = build_dataset(cfg)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step4: Build Trainer"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "work_dir = os.path.join(ROOT, cfg.trainer.work_dir)\n",
    "\n",
    "if not os.path.exists(work_dir):\n",
    "    os.makedirs(work_dir)\n",
    "cfg.dump(osp.join(work_dir, osp.basename(args.config)))\n",
    "\n",
    "trainer = build_trainer(cfg, default_args=dict(dataset=dataset, device = device))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step5: Train, Valid and Test"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "if task_name.startswith(\"train\"):\n",
    "    trainer.train_and_valid()\n",
    "    print(\"train end\")\n",
    "elif task_name.startswith(\"test\"):\n",
    "    trainer.test()\n",
    "    print(\"test end\")"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
